# Required packages: data wrangling/plotting (tidyverse), multinomial models
# (nnet), LaTeX tables (knitr, kableExtra) and figure composition (patchwork).
required_packages <- c("tidyverse", "nnet", "knitr", "kableExtra", "patchwork")

# Install only packages that cannot be loaded. requireNamespace() is the
# documented availability check; ?installed.packages warns it is slow and not
# meant for testing whether a package is installed/usable.
is_available <- vapply(
  required_packages,
  requireNamespace,
  logical(1),
  quietly = TRUE
)
new_packages <- required_packages[!is_available]
if (length(new_packages) > 0) {
  install.packages(new_packages)
}

# Attach all packages. invisible() stops the script from printing the list of
# search paths that library() returns for every package.
invisible(lapply(required_packages, library, character.only = TRUE))

# Create output directories; recursive = TRUE also creates "outputs" itself.
dir.create("outputs/tables", recursive = TRUE, showWarnings = FALSE)
dir.create("outputs/figures", recursive = TRUE, showWarnings = FALSE)

# Outcome categories in display order. "Abstainer/null" is listed first so it
# becomes the reference level of party_pref; it is never plotted.
party_order <- c("Abstainer/null", "Podemos", "Sumar", "PSOE", "Other", "PP", "SALF", "VOX")

# Categories shown in plots and tables: everything except the reference level.
plot_parties <- setdiff(party_order, "Abstainer/null")

# One fixed colour per plotted party, used by scale_color_manual().
party_colors <- c(
  Podemos = "#984EA3",  # purple
  Sumar   = "#FF69B4",  # hot pink
  PSOE    = "#E41A1C",  # red
  VOX     = "#4DAF4A",  # green
  PP      = "#377EB8",  # blue
  Other   = "#808080",  # grey
  SALF    = "#A6761D"   # brown
)

# Read the merged survey file and build the modelling variables:
#  * keep respondents whose party_pref is one of party_order;
#  * double lr for the "gesop" survey, leave other surveys unchanged;
#  * gender: values other than "Female"/"Male" (including "Other") become NA,
#    because factor() maps anything outside `levels` to NA;
#  * edu_binary: "Tertiary studies" vs every other non-missing edu_fct;
#  * party_pref: factor whose first level ("Abstainer/null") is the reference;
#  * survey: factor, for use as a model control;
#  * finally drop rows missing any of gender, age, edu_binary or party_pref.
data <- read_csv("merged_survey_data.csv") %>%
  filter(party_pref %in% party_order) %>%
  mutate(
    lr = if_else(survey == "gesop", lr * 2, lr),
    gender = factor(gender, levels = c("Female", "Male")),
    edu_binary = factor(
      if_else(edu_fct == "Tertiary studies", "Tertiary", "Non-tertiary"),
      levels = c("Non-tertiary", "Tertiary")
    ),
    party_pref = factor(party_pref, levels = party_order),
    survey = factor(survey)
  ) %>%
  drop_na(gender, age, edu_binary, party_pref)

#==========================================================================================
# FUNCTION TO EXTRACT MODEL FIT STATISTICS
#==========================================================================================

#' Summarise the fit of a multinomial party-preference model
#'
#' Reports the model's log-likelihood, AIC, BIC and degrees of freedom, plus
#' McFadden's pseudo-R^2 relative to an intercept-only model.
#'
#' @param model Fitted model supporting logLik(), AIC() and BIC()
#'   (in this script, an nnet::multinom fit).
#' @param data Data frame the model was fitted on. Its `response` column gives
#'   the observed category counts for the intercept-only model, and its row
#'   count is reported as `n_obs`.
#' @param response Name of the categorical response column in `data`.
#' @return Named list with logLik, AIC, BIC, McFadden_R2 (NA when fewer than
#'   two response categories are observed), n_obs and df.
get_model_fit <- function(model, data, response = "party_pref") {
  model_ll <- logLik(model)
  ll <- as.numeric(model_ll)

  # The intercept-only multinomial MLE predicts every category's observed
  # share, so its log-likelihood is sum(n_k * log(n_k / N)). Using the closed
  # form avoids refitting a null model, which could fail (silently giving NA)
  # or stop early at multinom's default iteration limit.
  y <- data[[response]]
  counts <- if (is.null(y)) integer(0) else as.vector(table(y))
  counts <- counts[counts > 0]
  mcfadden_r2 <- if (length(counts) >= 2) {
    ll_null <- sum(counts * log(counts / sum(counts)))
    1 - ll / ll_null
  } else {
    NA_real_
  }

  list(
    logLik = ll,
    AIC = AIC(model),
    BIC = BIC(model),
    McFadden_R2 = mcfadden_r2,
    n_obs = nrow(data),
    df = attr(model_ll, "df")
  )
}

#==========================================================================================
# FUNCTION TO RUN MULTINOMIAL MODEL FOR SOCIODEMOGRAPHICS
#==========================================================================================

#' Fit the sociodemographic multinomial model for a single survey
#'
#' Fits `party_pref ~ gender + age + edu_binary` with nnet::multinom on the
#' rows of the global `data` whose `survey` equals `survey_name`. Returns the
#' coefficients with Wald standard errors, two-sided p-values and significance
#' stars.
#'
#' @param survey_name Survey identifier compared against `data$survey`.
#' @return NULL if the survey is skipped, which happens when it has:
#'   - fewer than 100 respondents,
#'   - fewer than three parties with at least 10 respondents,
#'   - no respondents in the reference category, or
#'   - a model-fitting error.
#'   Otherwise a list with `survey`, `coefficients` (one row per party x term),
#'   `sociodem_log_odds` (gender, age and education terms only), `n`, `model`
#'   and `model_fit`.
get_sociodem_model_results <- function(survey_name) {
  survey_data <- data %>% filter(survey == survey_name)
  n_respondents <- nrow(survey_data)

  # Check data sufficiency
  if (n_respondents < 100) {
    message("Insufficient data for survey ", survey_name, " (n=", n_respondents, ")")
    return(NULL)
  }

  party_counts <- table(survey_data$party_pref)
  if (sum(party_counts >= 10) < 3) {
    message("Insufficient party variation for survey ", survey_name)
    return(NULL)
  }

  # multinom() silently drops empty outcome levels. If the reference level
  # (first level of party_pref) were empty, every coefficient would be relative
  # to some other party and could not be compared with the other surveys.
  reference_level <- levels(survey_data$party_pref)[1]
  if (party_counts[[reference_level]] == 0) {
    message("No respondents in reference category '", reference_level,
            "' for survey ", survey_name)
    return(NULL)
  }

  # Report rare categories
  rare_parties <- names(party_counts[party_counts < 20])
  if (length(rare_parties) > 0) {
    message("Warning: Rare party categories in survey ", survey_name, ": ",
            paste(rare_parties, collapse = ", "))
  }

  # Fit multinomial model with party_pref as DV and sociodemographics as IVs
  model <- tryCatch({
    multinom(party_pref ~ gender + age + edu_binary, data = survey_data, trace = FALSE, maxit = 500)
  }, error = function(e) {
    message("Error in model for survey ", survey_name, ": ", e$message)
    NULL
  })
  if (is.null(model)) return(NULL)

  # summary() recomputes the Hessian, so call it only once.
  model_summary <- summary(model)
  coefs <- model_summary$coefficients
  ses <- model_summary$standard.errors
  # 2 * pnorm(-|z|) avoids the cancellation in 1 - pnorm(|z|) for large |z|.
  p_values <- 2 * pnorm(-abs(coefs / ses))

  # One row per (party, term) with parties varying slowest; t() turns the
  # column-major matrices into that row-major order.
  coef_data <- data.frame(
    party = rep(rownames(coefs), each = ncol(coefs)),
    term = rep(colnames(coefs), times = nrow(coefs)),
    estimate = as.vector(t(coefs)),
    std.error = as.vector(t(ses)),
    p.value = as.vector(t(p_values)),
    stringsAsFactors = FALSE
  ) %>%
    mutate(
      stars = case_when(
        p.value < 0.001 ~ "***",
        p.value < 0.01 ~ "**",
        p.value < 0.05 ~ "*",
        p.value < 0.1 ~ ".",
        TRUE ~ ""
      ),
      coef_with_se = sprintf("%.3f%s\n(%.3f)", estimate, stars, std.error)
    )

  # Model terms to report, with their display labels (gender, age, education)
  sociodem_terms <- c(
    genderMale = "Gender (Male)",
    age = "Age",
    edu_binaryTertiary = "Education (Tertiary)"
  )
  sociodem_log_odds <- coef_data %>%
    filter(term %in% names(sociodem_terms)) %>%
    arrange(match(term, names(sociodem_terms))) %>%
    mutate(survey = survey_name, variable = unname(sociodem_terms[term])) %>%
    select(survey, party, variable, estimate, std.error, stars)

  list(
    survey = survey_name,
    coefficients = coef_data,
    sociodem_log_odds = sociodem_log_odds,
    n = n_respondents,
    model = model,
    model_fit = get_model_fit(model, survey_data)
  )
}

#' Fit the pooled sociodemographic model across all surveys
#'
#' Fits `party_pref ~ gender + age + edu_binary + survey` with nnet::multinom,
#' so `survey` acts as a control. Returns the same structure as
#' get_sociodem_model_results(), labelled "Overall".
#'
#' @param data Data frame with party_pref, gender, age, edu_binary and survey.
#' @return NULL if the model fails to fit. Otherwise a list with `survey`
#'   ("Overall"), `coefficients`, `sociodem_log_odds`, `n`, `model` and
#'   `model_fit`.
get_combined_sociodem_model <- function(data) {
  message("Running combined sociodemographic model with survey as control...")

  model <- tryCatch({
    multinom(party_pref ~ gender + age + edu_binary + survey,
             data = data, trace = FALSE, maxit = 1000)
  }, error = function(e) {
    message("Error in combined model: ", e$message)
    NULL
  })
  if (is.null(model)) return(NULL)

  # summary() recomputes the Hessian, so call it only once.
  model_summary <- summary(model)
  coefs <- model_summary$coefficients
  ses <- model_summary$standard.errors
  # 2 * pnorm(-|z|) avoids the cancellation in 1 - pnorm(|z|) for large |z|.
  p_values <- 2 * pnorm(-abs(coefs / ses))

  # One row per (party, term) with parties varying slowest; t() turns the
  # column-major matrices into that row-major order.
  coef_data <- data.frame(
    party = rep(rownames(coefs), each = ncol(coefs)),
    term = rep(colnames(coefs), times = nrow(coefs)),
    estimate = as.vector(t(coefs)),
    std.error = as.vector(t(ses)),
    p.value = as.vector(t(p_values)),
    stringsAsFactors = FALSE
  ) %>%
    mutate(
      stars = case_when(
        p.value < 0.001 ~ "***",
        p.value < 0.01 ~ "**",
        p.value < 0.05 ~ "*",
        p.value < 0.1 ~ ".",
        TRUE ~ ""
      ),
      coef_with_se = sprintf("%.3f%s\n(%.3f)", estimate, stars, std.error)
    )

  # Model terms to report, with their display labels (survey dummies excluded)
  sociodem_terms <- c(
    genderMale = "Gender (Male)",
    age = "Age",
    edu_binaryTertiary = "Education (Tertiary)"
  )
  sociodem_log_odds <- coef_data %>%
    filter(term %in% names(sociodem_terms)) %>%
    arrange(match(term, names(sociodem_terms))) %>%
    mutate(survey = "Overall", variable = unname(sociodem_terms[term])) %>%
    select(survey, party, variable, estimate, std.error, stars)

  list(
    survey = "Overall",
    coefficients = coef_data,
    sociodem_log_odds = sociodem_log_odds,
    n = nrow(data),
    model = model,
    model_fit = get_model_fit(model, data)
  )
}

#==========================================================================================
# RUN MODELS FOR EACH SURVEY
#==========================================================================================

# Fit one model per survey. Survey ids are converted to plain character
# strings, so each result is keyed by (and carries) a string rather than a
# length-1 factor, whatever map() does when iterating over a factor.
unique_surveys <- as.character(unique(data$survey))
survey_results <- map(unique_surveys, get_sociodem_model_results)
names(survey_results) <- unique_surveys

# Drop surveys that were skipped (their model function returned NULL)
valid_results <- Filter(Negate(is.null), survey_results)

# Fit the pooled model (survey as control) and append it as "overall"
combined_result <- get_combined_sociodem_model(data)
if (!is.null(combined_result)) {
  valid_results[["overall"]] <- combined_result
  message("Combined model added to results successfully")
} else {
  message("WARNING: Combined model failed to run")
}

#==========================================================================================
# FIGURE GENERATION - LOG-ODDS
#==========================================================================================

# Collect the sociodemographic log-odds of every model (per survey plus the
# pooled "Overall" model) for the plotted parties, with 95% Wald bounds.
plot_logodds <- map_dfr(valid_results, function(result) {
  if (is.null(result$sociodem_log_odds)) return(NULL)

  result$sociodem_log_odds %>%
    filter(party %in% plot_parties) %>%
    mutate(
      sample_size = result$n,
      lower = estimate - 1.96 * std.error,
      upper = estimate + 1.96 * std.error
    )
})

# Flag survey-specific estimates lying more than 5 MADs from the median of
# the survey-specific estimates for the same variable and party. The pooled
# "Overall" estimate is excluded from both the statistics and the flagging:
# the figure always draws it, so the table must keep it as well.
outliers <- plot_logodds %>%
  filter(survey != "Overall") %>%
  group_by(variable, party) %>%
  mutate(
    median_est = median(estimate, na.rm = TRUE),
    mad_est = mad(estimate, na.rm = TRUE),
    is_extreme = abs(estimate - median_est) > 5 * mad_est
  ) %>%
  ungroup() %>%
  filter(is_extreme)

if (nrow(outliers) > 0) {
  message("Identified ", nrow(outliers), " extreme outliers in log-odds data")
  print(outliers %>% select(variable, survey, party, estimate))

  # Filter out extreme outliers from plot data
  plot_logodds_clean <- plot_logodds %>%
    anti_join(outliers %>% select(variable, survey, party), by = c("variable", "survey", "party"))
} else {
  plot_logodds_clean <- plot_logodds
}

# Respondents per survey and party. Keys are made character so the join with
# the character columns of plot_logodds_clean needs no factor coercion.
survey_sizes <- data %>%
  count(survey, party_pref, name = "party_sample_size") %>%
  transmute(
    survey = as.character(survey),
    party = as.character(party_pref),
    party_sample_size
  ) %>%
  filter(party %in% plot_parties)

# Combine log-odds with sample sizes
plot_logodds_clean <- plot_logodds_clean %>%
  left_join(survey_sizes, by = c("survey", "party"))

# Pooled estimates are always plotted, taken from the unfiltered data
overall_logodds <- plot_logodds %>%
  filter(survey == "Overall")

# Ensure parties appear in the plot_parties order
plot_logodds_clean$party <- factor(plot_logodds_clean$party, levels = plot_parties)
overall_logodds$party <- factor(overall_logodds$party, levels = plot_parties)

# ggplot2 point-shape code for each survey id; "Overall" uses a filled dot.
survey_shapes <- c(
  andpol       = 1,
  cis_bjul     = 2,
  cis_bjun     = 0,
  db40_july    = 5,
  cis_campaign = 6,
  cis_pos      = 3,
  cis_pre      = 4,
  db40_june    = 7,
  gesop        = 8,
  ees          = 9,
  db40_august  = 10,
  Overall      = 16
)

# Human-readable legend label for each survey id (same ids as survey_shapes).
survey_labels <- c(
  andpol       = "Original survey",
  cis_bjul     = "CIS July",
  cis_bjun     = "CIS June",
  db40_july    = "40dB July",
  cis_campaign = "CIS EP Campaign",
  cis_pos      = "CIS EP Post-electoral",
  cis_pre      = "CIS EP Pre-electoral",
  db40_june    = "40dB June",
  gesop        = "GESOP",
  ees          = "EES",
  db40_august  = "40dB August",
  Overall      = "Overall"
)

# Display names of the three sociodemographic predictors, one panel each.
sociodem_vars <- c("Gender (Male)", "Age", "Education (Tertiary)")

# Build the log-odds dot plot for one sociodemographic variable.
#
# For each plotted party, the plot draws:
#   - the pooled ("Overall") coefficient, with its 95% interval and a numeric
#     label, sized by the overall sample size;
#   - the jittered survey-specific coefficients, with shape = survey and
#     size = that party's sample size in the survey.
# Only the Education panel shows legends, so the stacked figure carries a
# single set of legends.
#
# var_name: a value of the `variable` column (one of `sociodem_vars`).
# Returns a ggplot object.
create_sociodem_plot <- function(var_name) {
  var_data <- plot_logodds_clean %>%
    filter(variable == var_name)

  var_overall <- overall_logodds %>%
    filter(variable == var_name)

  # Determine if this is the Education plot (to show legend)
  show_legend <- var_name == "Education (Tertiary)"

  # Only caption panels that actually had an outlier removed
  outliers_removed <- any(outliers$variable == var_name)

  ggplot() +
    # Overall log-odds with error bars
    geom_point(data = var_overall,
               aes(x = party, y = estimate, color = party, size = sample_size),
               alpha = 0.7) +
    geom_linerange(data = var_overall,
                   aes(x = party, ymin = lower, ymax = upper, color = party),
                   size = 1, alpha = 0.7) +
    # Survey-specific log-odds
    geom_point(data = var_data %>% filter(survey != "Overall"),
               aes(x = party, y = estimate, color = party, size = party_sample_size, shape = survey),
               position = position_jitter(width = 0.2, height = 0), alpha = 0.5) +
    # Value labels for overall estimates
    geom_text(data = var_overall,
              aes(x = party, y = estimate, label = sprintf("%.2f", estimate)),
              color = "black", vjust = -1, size = 3) +
    # Reference line at 0 (no effect)
    geom_hline(yintercept = 0, linetype = "dashed", color = "gray50") +
    # Styling
    scale_color_manual(values = party_colors, guide = "none") +
    scale_size_continuous(name = "Sample size (N)",
                          breaks = c(50, 100, 500, 1000, 5000),
                          labels = scales::comma(c(50, 100, 500, 1000, 5000)),
                          range = c(1, 8),
                          guide = if (show_legend) "legend" else "none") +
    scale_shape_manual(name = "Survey", values = survey_shapes, labels = survey_labels,
                       guide = if (show_legend) "legend" else "none") +
    theme_minimal() +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1),
      legend.position = if (show_legend) "right" else "none",
      panel.grid.minor = element_blank(),
      strip.text = element_text(size = 12, face = "bold")
    ) +
    coord_flip() +
    labs(
      x = "",
      y = "Log-Odds Coefficient",
      title = var_name,
      caption = if (outliers_removed) "Note: Extreme outliers removed for better visualization" else ""
    )
}

# One panel per sociodemographic variable (list named by variable), stacked
# in a single column and written to outputs/figures.
plot_list <- setNames(map(sociodem_vars, create_sociodem_plot), sociodem_vars)
combined_plot <- wrap_plots(plot_list, ncol = 1)

ggsave(
  filename = file.path("outputs", "figures", "figure_sociodem_logodds.png"),
  plot = combined_plot,
  width = 10,
  height = 12,
  bg = "white"
)

message("Sociodemographic log-odds plots created successfully")

#==========================================================================================
# TABLE GENERATION - LOG-ODDS
#==========================================================================================

# LaTeX table of sociodemographic log-odds.
#   Rows: one per (variable, survey), with "Overall" last within each variable.
#   Columns: one per plotted party.
#   Cells: "estimate+stars" followed by "(SE)".
# kable() is called with escape = FALSE, so LaTeX special characters in free
# text must be escaped by hand. Survey ids such as "cis_bjul" contain
# underscores that would otherwise break compilation.
escape_latex <- function(x) gsub("([_&%$#])", "\\\\\\1", x)

logodds_table <- plot_logodds_clean %>%
  mutate(formatted_coef = sprintf("%.3f%s\n(%.3f)", estimate, stars, std.error)) %>%
  pivot_wider(
    id_cols = c(survey, variable),
    names_from = party,
    values_from = formatted_coef,
    names_sort = TRUE  # party columns follow the factor levels (plot_parties)
  ) %>%
  # Put "Overall" last within each variable: FALSE sorts before TRUE. A "ZZZ"
  # sort key is not enough, because dplyr >= 1.1 sorts in the C locale, where
  # upper case precedes the lower-case survey ids.
  arrange(variable, survey == "Overall", survey) %>%
  mutate(survey = escape_latex(survey))

# If there were outliers, note them in the table footnote
outlier_note <- if (nrow(outliers) > 0) {
  outlier_info <- outliers %>%
    mutate(desc = sprintf("%s for %s (%s: %.2f)", variable, as.character(party),
                          escape_latex(as.character(survey)), estimate)) %>%
    pull(desc) %>%
    paste(collapse = ", ")
  paste("Extreme outliers removed from table for clarity:", outlier_info)
} else {
  ""
}

base_note <- "Cell entries show log-odds coefficients from multinomial logistic regression models predicting party preference. Standard errors in parentheses. The 'Overall' rows show coefficients from a combined model with survey as a control variable."

# The spanning header must cover exactly the party columns actually present.
n_party_cols <- ncol(logodds_table) - 2

# LaTeX numbers captions itself ("Table N: ..."), so no "Table:" prefix here.
latex_logodds_table <- kable(logodds_table, format = "latex", booktabs = TRUE,
                             caption = "Log-Odds Coefficients of Sociodemographic Variables on Party Preference",
                             label = "tab:sociodem_logodds",
                             escape = FALSE) %>%
  kable_styling(latex_options = c("striped", "scale_down", "hold_position")) %>%
  column_spec(1:2, bold = FALSE) %>%
  add_header_above(c(" " = 2, "Party (Reference: Abstainer/null)" = n_party_cols)) %>%
  footnote(
    general = paste(c(base_note, outlier_note[nzchar(outlier_note)]), collapse = " "),
    symbol = c("p < 0.1", "p < 0.05", "p < 0.01", "p < 0.001"),
    symbol_manual = c(".", "*", "**", "***"),
    threeparttable = TRUE,
    escape = FALSE
  )

# Write to file
writeLines(latex_logodds_table, file.path("outputs/tables", "table_sociodem_logodds.tex"))

#==========================================================================================
# MODEL FIT STATISTICS TABLE
#==========================================================================================

# One row of fit statistics per model (each survey, then "Overall"), with
# numbers pre-formatted as strings for the LaTeX table.
model_fit_table <- map_dfr(valid_results, function(result) {
  tibble(
    Survey = as.character(result$survey),
    N = result$n,
    LL = sprintf("%.2f", result$model_fit$logLik),
    AIC = sprintf("%.2f", result$model_fit$AIC),
    BIC = sprintf("%.2f", result$model_fit$BIC),
    McFadden_R2 = sprintf("%.3f", result$model_fit$McFadden_R2),
    df = result$model_fit$df
  )
})

# escape = TRUE so survey ids containing underscores (e.g. "cis_bjul") are
# valid LaTeX; no cell or header here contains intentional LaTeX markup.
# LaTeX numbers captions itself ("Table N: ..."), so no "Table:" prefix here.
latex_model_fit <- kable(model_fit_table, format = "latex", booktabs = TRUE,
                         caption = "Model Fit Statistics for Sociodemographic Models by Survey",
                         label = "tab:sociodem_model_fit",
                         col.names = c("Survey", "N", "Log-Likelihood", "AIC", "BIC", "McFadden R²", "df"),
                         align = c("l", "r", "r", "r", "r", "r", "r"),
                         escape = TRUE) %>%
  kable_styling(latex_options = c("striped", "hold_position")) %>%
  footnote(
    general = "All survey-specific models include gender, age, and education as predictors. The 'Overall' model includes the same predictors plus survey as a control variable.",
    threeparttable = TRUE
  )

# Write to file
writeLines(latex_model_fit, file.path("outputs/tables", "table_sociodem_model_fit.tex"))

message("Analysis completed. Output files saved.")
